#!/usr/bin/env python

import argparse
import os

import numpy as np
import pandas as pd


def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the stack-health summarizer.

    Every option is a plain string option with a default; ``--mass-col``
    defaults to ``None`` so the caller can fall back to guessing the column.

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        description="Summarize per-stack health for T3 (S/N, window width, "
        "flatness) by mass bin and size bin."
    )
    # (flag, default, help) triples, registered in this exact order.
    option_table = (
        (
            "--stack-health",
            "outputs/stack_health.csv",
            "Path to stack_health.csv (default: outputs/stack_health.csv)",
        ),
        (
            "--mass-col",
            None,
            "Name of the mass-bin column. If not given, will try common names "
            "like 'Mstar_bin' or 'mass_bin'.",
        ),
        (
            "--size-col",
            "R_G_bin",
            "Name of the size-bin column (default: R_G_bin).",
        ),
        (
            "--out-mass-summary",
            "outputs/stack_health_summary_by_mass.csv",
            "CSV for mass-bin summary stats.",
        ),
        (
            "--out-mass-size-summary",
            "outputs/stack_health_summary_by_mass_size.csv",
            "CSV for (mass,size)-bin summary stats.",
        ),
    )
    for flag, default_value, help_text in option_table:
        parser.add_argument(flag, default=default_value, help=help_text)
    return parser


def guess_mass_col(df: pd.DataFrame) -> str:
    """Return the first well-known mass-bin column name present in *df*.

    Names are tried in priority order: ``Mstar_bin``, ``mass_bin``, ``M_bin``,
    ``Mstar_bin_label``.

    Raises:
        SystemExit: if none of those columns exist.
    """
    known_names = ("Mstar_bin", "mass_bin", "M_bin", "Mstar_bin_label")
    found = next((name for name in known_names if name in df.columns), None)
    if found is not None:
        return found
    raise SystemExit(
        "ERROR: Could not guess mass-bin column. "
        "Please re-run with --mass-col set to one of your mass bin columns."
    )


def _parse_rg_mid(label) -> float:
    """Return the numeric midpoint of one size-bin label, or NaN if unparseable.

    Handles interval-style strings such as ``"[0.5, 1.0)"``: brackets and
    parentheses are stripped and the mean of the two comma-separated numbers
    is returned. Single-number strings such as ``"0.75"`` are returned as
    floats. Real numbers (but not bools) are passed through as floats, so a
    numerically-typed size column is not silently turned into all-NaN.
    Anything else yields NaN.
    """
    if isinstance(label, bool):
        return np.nan
    if isinstance(label, (int, float, np.integer, np.floating)):
        return float(label)
    if not isinstance(label, str):
        return np.nan
    text = label.strip()
    for ch in "[]()":
        text = text.replace(ch, "")
    parts = text.split(",")
    try:
        if len(parts) == 2:
            return 0.5 * (float(parts[0]) + float(parts[1]))
        # Not a two-endpoint interval: accept a bare number, else NaN.
        return float(text)
    except ValueError:
        return np.nan


def _summarize_by_mass(df: pd.DataFrame, mass_col: str) -> pd.DataFrame:
    """Aggregate per-stack health metrics by mass bin (NaN bins are kept)."""
    spec = {"n_stacks": ("stack_id", "count")}
    # RG_mid_bin exists only when it was in the CSV or could be derived from
    # the size column; aggregating it unconditionally raised KeyError.
    if "RG_mid_bin" in df.columns:
        spec.update(
            median_RG_mid=("RG_mid_bin", "median"),
            min_RG_mid=("RG_mid_bin", "min"),
            max_RG_mid=("RG_mid_bin", "max"),
        )
    spec.update(
        median_A_theta=("A_theta", "median"),
        median_A_theta_SNR=("A_theta_SNR", "median"),
        min_A_theta_SNR=("A_theta_SNR", "min"),
        median_win_nbins=("win_nbins", "median"),
        min_win_nbins=("win_nbins", "min"),
        median_rmse_flat=("rmse_flat", "median"),
        median_R2_flat=("R2_flat", "median"),
    )
    return df.groupby(mass_col, dropna=False).agg(**spec).reset_index()


def _summarize_by_mass_size(
    df: pd.DataFrame, mass_col: str, size_col: str
) -> pd.DataFrame:
    """Aggregate per-stack health medians by (mass bin, size bin), NaN bins kept.

    Requires an ``RG_mid_bin`` column. ``main`` ensures this column exists
    whenever the size column is present.
    """
    return (
        df.groupby([mass_col, size_col], dropna=False)
        .agg(
            n_stacks=("stack_id", "count"),
            RG_mid_bin=("RG_mid_bin", "median"),
            A_theta=("A_theta", "median"),
            A_theta_SNR=("A_theta_SNR", "median"),
            win_nbins=("win_nbins", "median"),
            rmse_flat=("rmse_flat", "median"),
            R2_flat=("R2_flat", "median"),
        )
        .reset_index()
    )


def _write_csv(frame: pd.DataFrame, path: str) -> None:
    """Write *frame* to *path* without the index, creating parent dirs."""
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    frame.to_csv(path, index=False)


def main():
    """Summarize stack_health.csv by mass bin and by (mass, size) bin.

    Steps:
    - Read the CSV given by ``--stack-health``.
    - Write the mass-bin summary to ``--out-mass-summary``.
    - If the size column exists and the grouped result is non-empty, write
      the (mass, size) summary to ``--out-mass-size-summary``.
    - Print both tables.

    The CSV is expected to provide ``stack_id``, ``A_theta``,
    ``A_theta_SNR``, ``win_nbins``, ``rmse_flat`` and ``R2_flat``. If any of
    these is missing, pandas raises ``KeyError`` during aggregation.

    Raises:
        SystemExit: if the input CSV is missing or the mass-bin column
            cannot be found.
    """
    args = build_parser().parse_args()

    if not os.path.exists(args.stack_health):
        raise SystemExit(
            f"ERROR: stack health CSV not found at {args.stack_health}.\n"
            f"Run scripts/make_stack_health_table.py first."
        )

    df = pd.read_csv(args.stack_health)

    # Figure out which column is the mass bin
    mass_col = args.mass_col or guess_mass_col(df)
    if mass_col not in df.columns:
        raise SystemExit(f"ERROR: mass column '{mass_col}' not found in stack_health.csv.")

    size_col = args.size_col
    has_size = size_col in df.columns
    if not has_size:
        print(f"[warn] size column '{size_col}' not found; size-bin summaries may be limited.")

    # Derive RG_mid_bin from the size-bin labels when the CSV lacks it
    # (fallback mirroring make_stack_health_table.py).
    if "RG_mid_bin" not in df.columns and has_size:
        df["RG_mid_bin"] = df[size_col].apply(_parse_rg_mid)

    mass_summary = _summarize_by_mass(df, mass_col)
    if has_size:
        mass_size_summary = _summarize_by_mass_size(df, mass_col, size_col)
    else:
        mass_size_summary = pd.DataFrame()

    # --- Write outputs ---
    _write_csv(mass_summary, args.out_mass_summary)
    if not mass_size_summary.empty:
        _write_csv(mass_size_summary, args.out_mass_size_summary)

    # --- Print a readable summary ---
    print("\n=== Mass-bin summary ===\n")
    print(mass_summary.to_string(index=False))

    if not mass_size_summary.empty:
        print("\n=== Mass+size-bin summary ===\n")
        # Sort for nicer output
        mass_size_sorted = mass_size_summary.sort_values(
            by=[mass_col, "RG_mid_bin"], ascending=[True, True]
        )
        print(mass_size_sorted.to_string(index=False))
    else:
        print("\n[info] No size-bin column; mass+size summary skipped.")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
